package zmysql

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"time"

	"gitee.com/shijingzhao/go-frame/reflex"
	_ "github.com/go-sql-driver/mysql"
)

const DBTypeMysql = "mysql" // mysql数据库类型

const (
	maxIdleConn = 64
	maxOpenConn = 64
	maxLifetime = time.Minute
)

type Config struct {
	Source string `json:"source"`
}

type DB struct {
	db *sql.DB
}

// New 实例mysql
// 参数: source 数据库dsn root:J1ngzha0.com@tcp(127.0.0.1:3306)/db_user
// dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database)
func New(config Config) *DB {
	db, err := sql.Open(DBTypeMysql, config.Source)
	if err != nil {
		panic(fmt.Sprintf("sql.Open Error:%s", err.Error()))
	}

	// 设置基础参数
	db.SetMaxIdleConns(maxIdleConn)
	db.SetMaxOpenConns(maxOpenConn)
	db.SetConnMaxLifetime(maxLifetime)

	// 建立链接
	err = db.Ping()
	if err != nil {
		cErr := db.Close()
		panic(fmt.Sprintf("db.Ping Error:%s, db.Close Error:%s", err.Error(), cErr.Error()))
	}

	return &DB{db: db}
}

// Exec 执行语句
// 例一      插入 db.Exec(context.TODO(), "insert into t_user(`user_id`) value (?)", 100000006)
// 例二      修改 db.Exec(context.TODO(), "update t_user set `gender` = `gender` + 1 where `user_id` = ?", 100000005)
// 例三      删除 db.Exec(context.TODO(), "delete from t_user where `user_id` = ? limit 1", 100000005)
//
// @params  ctx 上下文 s sql语句 args 参数
// @return  sql.Result
//
//	         sql.Result.LastInsertId
//	         sql.Result.RowsAffected
//		     error
func (db DB) Exec(ctx context.Context, s string, args ...interface{}) (sql.Result, error) {
	var ret sql.Result

	stmt, err := db.db.PrepareContext(ctx, s)
	if err != nil {
		return ret, err
	}
	defer func() {
		_ = stmt.Close()
	}()

	ret, err = stmt.ExecContext(ctx, args...)
	if err != nil {
		return ret, err
	}

	return ret, nil
}

// Get 查询单行结果
// 例一      db.Get(context.TODO(), &tuser, "select * from t_user")
// 例二      db.Get(context.TODO(), &sqlRet, "select count(*) as total from t_user") type SqlRet struct {Total int64 `json:"total"`}
// @params  ctx:上下文 dest:结构 s:sql语句 args:参数 [必要] 查询语句查询的字段必须是结构的所有字段
// @return  error [必要] 必须处理返回的错误, 以便明确未查询到数据还是数据查询出现错误或是正常查询
//
//	switch {
//	case err == sql.ErrNoRows:  // 未获取到数据
//	case err != nil:            // 错误描述
//	default:                    // 查询数据
//	}
func (db DB) Get(ctx context.Context, dest interface{}, s string, args ...interface{}) error {
	stmt, err := db.db.PrepareContext(ctx, s)
	if err != nil {
		return err
	}
	defer func() {
		_ = stmt.Close()
	}()

	// 校验是否是结构类型
	v := reflect.ValueOf(dest).Elem()
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("应为%s,但得到%s", reflect.Struct, v.Kind())
	}

	values := make([]interface{}, v.NumField())
	reflex.ValuesPtr(v, values)

	err = stmt.QueryRowContext(ctx, args...).Scan(values...)
	return err
}

// Query 查询结果集
// 例一      db.Query(context.TODO(), &tusers, "select * from t_user")
// 例二      db.Query(context.TODO(), &tusers, "select * from t_user where user_id = ?", 100000000)
// @params  ctx:上下文 dest:结果集结构 s:sql语句 args:参数
// @return  error
func (db DB) Query(ctx context.Context, dest interface{}, s string, args ...interface{}) error {
	stmt, err := db.db.PrepareContext(ctx, s)
	if err != nil {
		return err
	}
	defer func() {
		_ = stmt.Close()
	}()

	rows, err := stmt.QueryContext(ctx, args...)
	if err != nil {
		return err
	}
	defer func() {
		_ = rows.Close()
	}()

	// 获取切片反射类型
	value := reflect.ValueOf(dest)
	if value.Kind() != reflect.Ptr {
		return fmt.Errorf("应为%s,但得到%s", reflect.Ptr, value.Kind())
	}
	direct := reflect.Indirect(value)

	// 基础类型
	slice := reflex.Deref(value.Type())
	if slice.Kind() != reflect.Slice {
		return fmt.Errorf("应为%s,但得到%s", reflect.Slice, slice.Kind())
	}
	base := reflex.Deref(slice.Elem())

	// 是否是指针
	isPtr := slice.Elem().Kind() == reflect.Ptr

	// 列字段集为scan赋值准备
	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	values := make([]interface{}, len(columns))

	// 获取数据集
	for rows.Next() {
		vp := reflect.New(base)
		v := reflect.Indirect(vp)

		// 结果指针
		reflex.ValuesPtr(v, values)

		// 获取数据
		err = rows.Scan(values...)
		if err != nil {
			return err
		}

		if isPtr {
			direct.Set(reflect.Append(direct, vp))
		} else {
			direct.Set(reflect.Append(direct, v))
		}
	}

	return rows.Err()
}
